In this section, we try to do customer segmentation with RFM (Recency, Frequency and Monetary Value) scores. Since we have the data for the whole year of 2017, we separate it into two sets: January-September and October-December. Then we will do customer segmentation for the first 9 months and predict the customer lifetime value (LTV) for the last 3 months.
Customer Segmentation (First 9 Months)
We calculate the recency (how many days have passed since the last transaction), frequency (how many transactions within 9 months) and log revenue for all transactions. Then we run clustering on each of these 3 features, giving us 3 cluster labels. We re-order the cluster labels so that the biggest cluster number represents the best result. We then calculate the overall score by summing up these 3 scores, and segment the customers into 4 groups according to the overall scores.
LTV Prediction (Last 3 Months)
First, we use the log revenue of the last 3 months to do a LTV clustering. Then we use the 9-month features to predict the 3-month LTV.
from __future__ import division
from datetime import datetime, timedelta,date
import pandas as pd
%matplotlib inline
from sklearn.metrics import classification_report,confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.cluster import KMeans
import chart_studio.plotly as py
import plotly.offline as pyoff
import plotly.graph_objs as go
import xgboost as xgb
from sklearn.model_selection import KFold, cross_val_score, train_test_split
import warnings
# Silence library warnings and enable offline plotly rendering in the notebook.
warnings.filterwarnings("ignore")
pyoff.init_notebook_mode()
# load data (pre-cleaned 2017 Google Analytics sessions)
df = pd.read_pickle('2017_clean.pkl')
train = df.copy()
# keep related information and change data type
train = train[['fullVisitorId', 'date', 'totals.totalTransactionRevenue']]
train["date"] = pd.to_datetime(train["date"], format="%Y%m%d") # setting the column as pandas datetime
train['totals.totalTransactionRevenue']=train['totals.totalTransactionRevenue'].astype('float')
train.head()
train['date'].describe()
# select data from Jan to Sept (feature window)
train_9 = train[(train.date < date(2017,10,1)) & (train.date >= date(2017,1,1))].reset_index(drop=True)
# select data from Oct to Dec (target / LTV window)
train_12 = train[(train.date >= date(2017,10,1)) & (train.date <= date(2017,12,31))].reset_index(drop=True)
train_9['date'].describe()
train_12['date'].describe()
# One row per distinct visitor, built from the full-year frame.
customer = pd.DataFrame(train['fullVisitorId'].unique())
customer.columns = ['fullVisitorId']
# Recency = days between a visitor's last Jan-Sep transaction and the latest
# transaction date observed in that window (0 = purchased most recently).
max_purchase = train_9.groupby('fullVisitorId').date.max().reset_index()
max_purchase.columns = ['fullVisitorId','MaxPurchaseDate']
max_purchase['Recency'] = (max_purchase['MaxPurchaseDate'].max() - max_purchase['MaxPurchaseDate']).dt.days
# NOTE(review): this inner merge drops visitors with no Jan-Sep activity.
user_recency = pd.merge(customer, max_purchase[['fullVisitorId','Recency']], on='fullVisitorId')
user_recency.head()
user_recency.Recency.describe()
# Visualize the distribution of Recency across customers.
plot_data = [
    go.Histogram(
        x=user_recency['Recency']
    )
]
plot_layout = go.Layout(
    title='Recency'
)
fig = go.Figure(data=plot_data, layout=plot_layout)
pyoff.iplot(fig)
# Elbow Method: plot KMeans inertia for k = 1..9 to choose the cluster count.
sse={}
recency = user_recency[['Recency']]
for k in range(1, 10):
    kmeans = KMeans(n_clusters=k, max_iter=1000).fit(recency)
    # NOTE(review): writing into this column slice can raise
    # SettingWithCopyWarning; the labels are not used afterwards anyway.
    recency["clusters"] = kmeans.labels_
    sse[k] = kmeans.inertia_
plt.figure()
plt.plot(list(sse.keys()), list(sse.values()))
plt.xlabel("Number of cluster")
plt.show()
# let's try 4 first
kmeans = KMeans(n_clusters=4)
kmeans.fit(user_recency[['Recency']])
user_recency['RecencyCluster'] = kmeans.predict(user_recency[['Recency']])
user_recency.groupby('RecencyCluster')['Recency'].describe()
def order_cluster(cluster_field_name, target_field_name,df,ascending):
    """Relabel the KMeans cluster ids in *cluster_field_name* so the ids are
    ordered by each cluster's mean of *target_field_name*.

    After the call, cluster 0 has the lowest (ascending=True) or highest
    (ascending=False) mean of the target, so the largest cluster number always
    represents the "best" group for RFM scoring.

    Parameters
    ----------
    cluster_field_name : str
        Column holding the raw KMeans labels; overwritten with ordered ids.
    target_field_name : str
        Numeric column whose per-cluster mean defines the ordering.
    df : pandas.DataFrame
        Input frame; NOT modified in place.
    ascending : bool
        Sort direction for the cluster means.

    Returns
    -------
    pandas.DataFrame
        Copy of *df* with relabelled clusters. Unlike the previous merge-based
        implementation, row order and column order are preserved, and the
        unused local variable has been removed.
    """
    # Mean of the target per raw cluster id, ranked in the requested direction.
    means = df.groupby(cluster_field_name)[target_field_name].mean()
    ordered_ids = means.sort_values(ascending=ascending).index
    # old raw id -> new ordered id (0 is worst-ranked, N-1 is best-ranked).
    relabel = {old: new for new, old in enumerate(ordered_ids)}
    df_final = df.copy()
    df_final[cluster_field_name] = df_final[cluster_field_name].map(relabel)
    return df_final
# call the function above to re-label the recency clusters; descending because
# a LOWER recency (fewer days since last purchase) is BETTER.
user_recency = order_cluster('RecencyCluster', 'Recency', user_recency,False)
user_recency.groupby('RecencyCluster')['Recency'].describe()
user_recency.groupby('RecencyCluster')['Recency'].agg(['count','mean']).reset_index()
user_recency.groupby('RecencyCluster')['Recency'].hist()
plt.xlabel('Recency')
plt.ylabel('Count')
plt.title('Recency_Cluster')
# Frequency = number of Jan-Sep transactions per visitor.
user_frequency = train_9.groupby('fullVisitorId').date.count().reset_index()
user_frequency.columns = ['fullVisitorId','Frequency']
user_frequency.head()
user = pd.merge(user_recency, user_frequency, on='fullVisitorId')
user.head()
user.Frequency.describe()
# Visualize the distribution of Frequency.
plot_data = [
    go.Histogram(
        x=user['Frequency']
    )
]
plot_layout = go.Layout(
    title='Frequency'
)
fig = go.Figure(data=plot_data, layout=plot_layout)
pyoff.iplot(fig)
# Elbow Method for the frequency feature (k = 1..14).
sse={}
frequency = user[['Frequency']]
for k in range(1, 15):
    kmeans = KMeans(n_clusters=k, max_iter=1000).fit(frequency)
    frequency["clusters"] = kmeans.labels_
    sse[k] = kmeans.inertia_
plt.figure()
plt.plot(list(sse.keys()), list(sse.values()))
plt.xlabel("Number of cluster")
plt.show()
kmeans = KMeans(n_clusters=4)
kmeans.fit(user[['Frequency']])
user['FrequencyCluster'] = kmeans.predict(user[['Frequency']])
# ascending: a HIGHER frequency is BETTER, so it gets a higher cluster number.
user = order_cluster('FrequencyCluster', 'Frequency',user,True)
user.groupby('FrequencyCluster')['Frequency'].describe()
user.groupby('FrequencyCluster')['Frequency'].agg(['count','mean']).reset_index()
# Monetary value: total Jan-Sep revenue per visitor, log1p-transformed to
# tame the heavy right tail.
revenue = train_9.groupby('fullVisitorId')['totals.totalTransactionRevenue'].sum().reset_index()
revenue.columns = ['fullVisitorId', 'Revenue']
revenue['logRevenue'] = np.log(1+revenue['Revenue'])
revenue = revenue.drop(['Revenue'], axis=1)
revenue.head()
user = pd.merge(user, revenue, on='fullVisitorId')
user.head()
user.logRevenue.describe()
# the elbow method shows 2 clusters for log revenue
kmeans = KMeans(n_clusters=2)
kmeans.fit(user[['logRevenue']])
user['logRevenueCluster'] = kmeans.predict(user[['logRevenue']])
user = order_cluster('logRevenueCluster', 'logRevenue',user,True)
user.groupby('logRevenueCluster')['logRevenue'].describe()
user.groupby('logRevenueCluster')['logRevenue'].agg(['count','mean']).reset_index()
user.head()
# Overall RFM score: sum of the three ordered cluster ids (range 0-7).
user['OverallScore'] = user['RecencyCluster'] + user['FrequencyCluster'] + user['logRevenueCluster']
# BUG FIX: select multiple groupby columns with a LIST; the previous
# `['Recency','Frequency','logRevenue']` tuple-style selection was deprecated
# and raises in pandas >= 2.0.
user.groupby('OverallScore')[['Recency','Frequency','logRevenue']].mean().reset_index()
Looks like customers with an overall score of 7 are the highest-value customers.
# Map the overall RFM score (0-7) into four named segments.
user['Segment'] = 'Low-Value'
user.loc[user['OverallScore']>0,'Segment'] = 'Mid1-Value'
user.loc[user['OverallScore']>3,'Segment'] = 'Mid2-Value'
user.loc[user['OverallScore']>5,'Segment'] = 'High-Value'
# 9-month mean log revenue per segment, sorted from worst to best.
sg = user.groupby('Segment')['logRevenue'].agg(['count','mean']).sort_values(by=['mean']).reset_index()
sg.columns = ['Segment','count','mean_logRevenue_9']
sg
# Random sample of 20k customers for the (commented-out) scatter plots below.
user_sample = user.sample(20000)
# tx_graph = user_sample
# plot_data = [
# go.Scatter(
# x=tx_graph.query("Segment == 'Low-Value'")['Recency'],
# y=tx_graph.query("Segment == 'Low-Value'")['Frequency'],
# mode='markers',
# name='Low',
# marker= dict(size= 7,
# line= dict(width=1),
# color= 'blue',
# opacity= 0.8
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'Mid1-Value'")['Recency'],
# y=tx_graph.query("Segment == 'Mid1-Value'")['Frequency'],
# mode='markers',
# name='Mid1',
# marker= dict(size= 9,
# line= dict(width=1),
# color= 'green',
# opacity= 0.5
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'Mid2-Value'")['Recency'],
# y=tx_graph.query("Segment == 'Mid2-Value'")['Frequency'],
# mode='markers',
# name='Mid2',
# marker= dict(size= 9,
# line= dict(width=1),
# color= 'orange',
# opacity= 0.5
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'High-Value'")['Recency'],
# y=tx_graph.query("Segment == 'High-Value'")['Frequency'],
# mode='markers',
# name='High',
# marker= dict(size= 11,
# line= dict(width=1),
# color= 'red',
# opacity= 0.9
# )
# ),
# ]
# plot_layout = go.Layout(
# yaxis= {'title': "Frequency"},
# xaxis= {'title': "Recency"},
# title='Segments'
# )
# fig = go.Figure(data=plot_data, layout=plot_layout)
# pyoff.iplot(fig)
# tx_graph = user_sample
# plot_data = [
# go.Scatter(
# x=tx_graph.query("Segment == 'Low-Value'")['Recency'],
# y=tx_graph.query("Segment == 'Low-Value'")['logRevenue'],
# mode='markers',
# name='Low',
# marker= dict(size= 7,
# line= dict(width=1),
# color= 'blue',
# opacity= 0.8
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'Mid1-Value'")['Recency'],
# y=tx_graph.query("Segment == 'Mid1-Value'")['logRevenue'],
# mode='markers',
# name='Mid1',
# marker= dict(size= 9,
# line= dict(width=1),
# color= 'green',
# opacity= 0.5
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'Mid2-Value'")['Recency'],
# y=tx_graph.query("Segment == 'Mid2-Value'")['logRevenue'],
# mode='markers',
# name='Mid2',
# marker= dict(size= 9,
# line= dict(width=1),
# color= 'orange',
# opacity= 0.5
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'High-Value'")['Recency'],
# y=tx_graph.query("Segment == 'High-Value'")['logRevenue'],
# mode='markers',
# name='High',
# marker= dict(size= 11,
# line= dict(width=1),
# color= 'red',
# opacity= 0.9
# )
# ),
# ]
# plot_layout = go.Layout(
# yaxis= {'title': "logRevenue"},
# xaxis= {'title': "Recency"},
# title='Segments'
# )
# fig = go.Figure(data=plot_data, layout=plot_layout)
# pyoff.iplot(fig)
# tx_graph = user_sample
# plot_data = [
# go.Scatter(
# x=tx_graph.query("Segment == 'Low-Value'")['Frequency'],
# y=tx_graph.query("Segment == 'Low-Value'")['logRevenue'],
# mode='markers',
# name='Low',
# marker= dict(size= 7,
# line= dict(width=1),
# color= 'blue',
# opacity= 0.8
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'Mid1-Value'")['Frequency'],
# y=tx_graph.query("Segment == 'Mid1-Value'")['logRevenue'],
# mode='markers',
# name='Mid1',
# marker= dict(size= 9,
# line= dict(width=1),
# color= 'green',
# opacity= 0.5
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'Mid2-Value'")['Frequency'],
# y=tx_graph.query("Segment == 'Mid2-Value'")['logRevenue'],
# mode='markers',
# name='Mid2',
# marker= dict(size= 9,
# line= dict(width=1),
# color= 'orange',
# opacity= 0.5
# )
# ),
# go.Scatter(
# x=tx_graph.query("Segment == 'High-Value'")['Frequency'],
# y=tx_graph.query("Segment == 'High-Value'")['logRevenue'],
# mode='markers',
# name='High',
# marker= dict(size= 11,
# line= dict(width=1),
# color= 'red',
# opacity= 0.9
# )
# ),
# ]
# plot_layout = go.Layout(
# yaxis= {'title': "logRevenue"},
# xaxis= {'title': "Frequency"},
# title='Segments'
# )
# fig = go.Figure(data=plot_data, layout=plot_layout)
# pyoff.iplot(fig)
user.head()
train_12.head()
# 3-month (Oct-Dec) revenue per visitor, log1p-transformed like the 9-month one.
revenue_12 = train_12.groupby('fullVisitorId')['totals.totalTransactionRevenue'].sum().reset_index()
revenue_12.columns = ['fullVisitorId','Revenue_12']
revenue_12['logRevenue_12'] = np.log(1+revenue_12['Revenue_12'])
revenue_12 = revenue_12.drop(['Revenue_12'], axis=1)
revenue_12.head()
# merge 9-month features with 3-month revenue; the left join keeps customers
# with no Oct-Dec purchases, whose revenue is then filled with 0 below
merge = pd.merge(user, revenue_12, on='fullVisitorId', how='left')
merge.head()
merge = merge.fillna(0)
sg_12 = merge.groupby('Segment')['logRevenue_12'].agg(['count','mean']).sort_values(by=['mean']).reset_index()
sg_12.columns = ['Segment','count','mean_logRevenue_12']
sg_12
merge.shape
# Cluster customers into 4 LTV groups on the 3-month log revenue; the ordered
# cluster id becomes the classification target for the models below.
kmeans = KMeans(n_clusters=4)
kmeans.fit(merge[['logRevenue_12']])
merge['LTVCluster'] = kmeans.predict(merge[['logRevenue_12']])
merge = order_cluster('LTVCluster', 'logRevenue_12',merge,True)
merge.groupby('LTVCluster')['logRevenue_12'].describe()
ltv = merge.groupby('LTVCluster')['logRevenue_12'].agg(['count','mean']).reset_index()
ltv.columns = ['LTVCluster','count','mean_logRevenue_12']
ltv
cluster = merge.copy()
cluster.head()
# dummy Segment: hand-rolled one-hot columns so the dummy names stay short
# (Segment_High rather than Segment_High-Value).
for short_name, full_name in [('High', 'High-Value'), ('Mid1', 'Mid1-Value'),
                              ('Mid2', 'Mid2-Value'), ('Low', 'Low-Value')]:
    cluster['Segment_' + short_name] = (cluster['Segment'] == full_name).astype(int)
cluster = cluster.drop(['Segment'],axis=1)
cluster.head()
# correlation of every feature with the 3-month (Oct-Dec) LTVCluster target
corr_matrix = cluster.corr()
corr_matrix['LTVCluster'].sort_values(ascending=False)
# Features = everything except the target, the raw 3-month revenue and the id.
X = cluster.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1)
y = cluster['LTVCluster']
X_columns = cluster.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1).columns
y_column = ['LTVCluster']
import lightgbm as lgb
from imblearn.over_sampling import SMOTE
# Model 1: base RFM features only. Hold out 20% of customers for evaluation.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=56)
# using SMOTE for unbalanced data (oversample the training split only, so the
# test set keeps its natural class distribution)
X_train, y_train = SMOTE().fit_resample(X_train, y_train)
# fitting lightGBM; FIX: the duplicated "min_child_weight" dict key was
# removed (a repeated literal key silently overwrites the first occurrence).
params = {"objective" : "multiclass",
          "num_class": 4,
          "metric" : "multi_error",
          "num_leaves" : 30,
          "min_child_weight" : 50,
          "learning_rate" : 0.1,
          "bagging_fraction" : 0.7,
          "feature_fraction" : 0.7,
          "reg_alpha": 0.15,
          "reg_lambda": 0.15,
          "bagging_seed" : 420,
          "verbosity" : -1
}
lg_train = lgb.Dataset(X_train, label=y_train)
lg_test = lgb.Dataset(X_test, label=y_test)
model = lgb.train(params, lg_train, 1000, valid_sets=[lg_test], early_stopping_rounds=50, verbose_eval=100)
y_pred = model.predict(X_test)  # (n_samples, 4) class-probability matrix
y_pred[0]
# Convert each row of class probabilities to the predicted label index.
# BUG FIX: np.argmax replaces int(np.where(i == np.amax(i))[0]), which raises
# "only size-1 arrays can be converted" whenever a row has tied maxima;
# argmax deterministically picks the first maximum.
y_pred_ar = np.argmax(y_pred, axis=1)
y_pred_ar
print(classification_report(y_test, y_pred_ar))
# Rebuild the Jan-Sep frame from the raw data to engineer extra features.
ft = df.copy()
ft["date"] = pd.to_datetime(ft["date"], format="%Y%m%d")
# select data from Jan to Sept
ft = ft[(ft.date < date(2017,10,1)) & (ft.date >= date(2017,1,1))].reset_index(drop=True)
ft.shape
# One-hot encode the visitor's continent, dropping the '(not set)' column.
ft1 = ft[['fullVisitorId','geoNetwork.continent']]
ft1 = ft1.join(pd.get_dummies(ft1['geoNetwork.continent'])).drop(['(not set)'],axis = 1).drop(['geoNetwork.continent'], axis=1)
# max-aggregate per visitor: 1 if the visitor ever had a session from that continent
ft1 = ft1.groupby(['fullVisitorId']).agg(['max'])
ft1.columns = ['_'.join(col).strip() for col in ft1.columns.values]
ft1 = ft1.reset_index()
ft1.head()
# merge continent features with cluster
ft1_df = pd.merge(cluster, ft1, on='fullVisitorId')
ft1_df.head()
# Model 2: base RFM features + continent dummies.
X = ft1_df.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1)
y = ft1_df['LTVCluster']
X_columns = ft1_df.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1).columns
y_column = ['LTVCluster']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=56)
# using SMOTE for unbalanced data (training split only)
X_train, y_train = SMOTE().fit_resample(X_train, y_train)
# change X, y to dataframe so that we can keep column names
X_train = pd.DataFrame(X_train, columns=X_columns)
y_train = pd.DataFrame(y_train, columns=y_column)
# take a look at y after SMOTE, 4 clusters have the same samples now
y_train.hist()
# fitting lightGBM; FIX: the duplicated "min_child_weight" dict key was
# removed (a repeated literal key silently overwrites the first occurrence).
params = {"objective" : "multiclass",
          "num_class": 4,
          "metric" : "multi_error",
          "num_leaves" : 30,
          "min_child_weight" : 50,
          "learning_rate" : 0.1,
          "bagging_fraction" : 0.7,
          "feature_fraction" : 0.7,
          "reg_alpha": 0.15,
          "reg_lambda": 0.15,
          "bagging_seed" : 420,
          "verbosity" : -1
}
lg_train = lgb.Dataset(X_train, label=y_train)
lg_test = lgb.Dataset(X_test, label=y_test)
model = lgb.train(params, lg_train, 1000, valid_sets=[lg_test], early_stopping_rounds=50, verbose_eval=100)
y_pred = model.predict(X_test)  # (n_samples, 4) class-probability matrix
y_pred[0]
# Convert each row of class probabilities to the predicted label index.
# BUG FIX: np.argmax replaces int(np.where(i == np.amax(i))[0]), which raises
# "only size-1 arrays can be converted" whenever a row has tied maxima.
y_pred_ar = np.argmax(y_pred, axis=1)
y_pred_ar
print(classification_report(y_test, y_pred_ar))
# Feature importance of the continent-augmented model.
fig, ax = plt.subplots(figsize=(8,5))
lgb.plot_importance(model, height=0.8, ax=ax)
ax.grid(False)
plt.ylabel('Feature', size=12)
plt.xlabel('Importance', size=12)
plt.title("Importance of the Features of our LightGBM Model", fontsize=12)
plt.show()
# Per-session flag: 1 when the session came from the United States.
ft2 = ft[['fullVisitorId','geoNetwork.country']]
ft2['country_US']=0
ft2.loc[ft2['geoNetwork.country']=='United States','country_US']=1
ft2 = ft2.drop(['geoNetwork.country'],axis=1)
ft2.head()
# check whether each customer stays only in the US / only outside the US, never both
us = ft2.groupby(['fullVisitorId']).agg(['sum'])
us.columns = ['_'.join(col).strip() for col in us.columns.values]
us = us.reset_index()
us_merge = pd.merge(user_frequency, us, on='fullVisitorId')
# us_only is True when every one of the customer's transactions was from the
# US (sum == frequency) or none were (sum == 0); single-transaction customers
# trivially qualify.
us_merge['us_only'] = (us_merge['Frequency']==us_merge['country_US_sum'])
us_merge.loc[us_merge['Frequency']== 1,'us_only']=True
us_merge.loc[us_merge['country_US_sum']== 0,'us_only']=True
us_merge['us_only'] = us_merge['us_only']*1
# number of customers that have been both in the US and outside the US
print(us_merge[us_merge['us_only']==0]['fullVisitorId'].count())
# percentage of customers that have been both in the US and outside the US
print(us_merge[us_merge['us_only']==0]['fullVisitorId'].count()/us_merge.fullVisitorId.count())
Since the vast majority of customers (99.8%) stay either only in the US or only outside the US across all of their transactions, we can treat whether a customer is in the US as a stable attribute of that customer.
# Collapse to one row per visitor: country_US = 1 if any session was from the US.
ft2 = ft2.groupby(['fullVisitorId']).agg(['max'])
ft2.columns = ['_'.join(col).strip() for col in ft2.columns.values]
ft2 = ft2.reset_index()
ft2.columns = ['fullVisitorId', 'country_US']
ft2.head()
# Model 3: merge the country_US flag onto the continent-augmented features.
# (FIX: the old comment incorrectly said "continent features".)
ft2_df = pd.merge(ft1_df, ft2, on='fullVisitorId')
ft2_df.head()
X = ft2_df.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1)
y = ft2_df['LTVCluster']
X_columns = ft2_df.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1).columns
y_column = ['LTVCluster']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=56)
# using SMOTE for unbalanced data (training split only)
X_train, y_train = SMOTE().fit_resample(X_train, y_train)
# change X, y to dataframe so that we can keep column names
X_train = pd.DataFrame(X_train, columns=X_columns)
y_train = pd.DataFrame(y_train, columns=y_column)
# take a look at y after SMOTE, 4 clusters have the same samples now
y_train.hist()
# fitting lightGBM; FIX: the duplicated "min_child_weight" dict key was
# removed (a repeated literal key silently overwrites the first occurrence).
params = {"objective" : "multiclass",
          "num_class": 4,
          "metric" : "multi_error",
          "num_leaves" : 30,
          "min_child_weight" : 50,
          "learning_rate" : 0.1,
          "bagging_fraction" : 0.7,
          "feature_fraction" : 0.7,
          "reg_alpha": 0.15,
          "reg_lambda": 0.15,
          "bagging_seed" : 420,
          "verbosity" : -1
}
lg_train = lgb.Dataset(X_train, label=y_train)
lg_test = lgb.Dataset(X_test, label=y_test)
model = lgb.train(params, lg_train, 1000, valid_sets=[lg_test], early_stopping_rounds=50, verbose_eval=100)
y_pred = model.predict(X_test)  # (n_samples, 4) class-probability matrix
y_pred[0]
# Convert each row of class probabilities to the predicted label index.
# BUG FIX: np.argmax replaces int(np.where(i == np.amax(i))[0]), which raises
# "only size-1 arrays can be converted" whenever a row has tied maxima.
y_pred_ar = np.argmax(y_pred, axis=1)
y_pred_ar
print(classification_report(y_test, y_pred_ar))
# Feature importance of the continent + country_US model.
fig, ax = plt.subplots(figsize=(8,5))
lgb.plot_importance(model, height=0.8, ax=ax)
ax.grid(False)
plt.ylabel('Feature', size=12)
plt.xlabel('Importance', size=12)
plt.title("Importance of the Features of our LightGBM Model", fontsize=12)
plt.show()
# Model 4 (final choice): base RFM features + country_US flag only.
ft_us = pd.merge(cluster, ft2, on='fullVisitorId')
ft_us.head()
X = ft_us.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1)
y = ft_us['LTVCluster']
X_columns = ft_us.drop(['LTVCluster','logRevenue_12','fullVisitorId'],axis=1).columns
y_column = ['LTVCluster']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=56)
# using SMOTE for unbalanced data (training split only)
X_train, y_train = SMOTE().fit_resample(X_train, y_train)
# change X, y to dataframe so that we can keep column names
X_train = pd.DataFrame(X_train, columns=X_columns)
y_train = pd.DataFrame(y_train, columns=y_column)
# take a look at y after SMOTE, 4 clusters have the same samples now
y_train.hist()
# fitting lightGBM; FIX: the duplicated "min_child_weight" dict key was
# removed (a repeated literal key silently overwrites the first occurrence).
params = {"objective" : "multiclass",
          "num_class": 4,
          "metric" : "multi_error",
          "num_leaves" : 30,
          "min_child_weight" : 50,
          "learning_rate" : 0.1,
          "bagging_fraction" : 0.7,
          "feature_fraction" : 0.7,
          "reg_alpha": 0.15,
          "reg_lambda": 0.15,
          "bagging_seed" : 420,
          "verbosity" : -1
}
lg_train = lgb.Dataset(X_train, label=y_train)
lg_test = lgb.Dataset(X_test, label=y_test)
model = lgb.train(params, lg_train, 1000, valid_sets=[lg_test], early_stopping_rounds=50, verbose_eval=100)
y_pred = model.predict(X_test)  # (n_samples, 4) class-probability matrix
y_pred[0]
# Convert each row of class probabilities to the predicted label index.
# BUG FIX: np.argmax replaces int(np.where(i == np.amax(i))[0]), which raises
# "only size-1 arrays can be converted" whenever a row has tied maxima.
y_pred_ar = np.argmax(y_pred, axis=1)
y_pred_ar
print(classification_report(y_test, y_pred_ar))
# Feature importance of the final (country_US only) model.
fig, ax = plt.subplots(figsize=(8,5))
lgb.plot_importance(model, height=0.8, ax=ax)
ax.grid(False)
plt.ylabel('Feature', size=12)
plt.xlabel('Importance', size=12)
plt.title("Importance of the Features of our LightGBM Model", fontsize=12)
plt.show()
Adding both the continent dummies and country_US gives almost the same result as adding country_US only. To avoid overfitting, we select the third model, which adds country_US only.
from sklearn import metrics
from sklearn.utils.multiclass import unique_labels
# NOTE(review): `metrics` and `unique_labels` are imported here but never used.
class_names=[0,1,2,3] # name of classes (the four ordered LTV clusters)
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """Print and plot the confusion matrix of *y_true* vs *y_pred*.

    Parameters
    ----------
    y_true, y_pred : array-like of int
        True and predicted class labels.
    classes : list
        Tick labels, in class-index order.
    normalize : bool
        If True, normalize each row (true class) so it sums to 1.
    title : str or None
        Plot title; a default is chosen based on *normalize*.
    cmap : matplotlib colormap
        Colormap for the heatmap.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the matrix was drawn on.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'
    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        # BUG FIX: normalize by ROW sums (axis=1), i.e. the number of true
        # samples per class. The previous axis=0 divided by column sums while
        # broadcasting row-wise, producing values that are not per-class rates.
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # Show every tick and label it with the supplied class names.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Annotate each cell with its count/rate in a colour that stays readable
    # against the heatmap background.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax
np.set_printoptions(precision=2)
# Plot non-normalized confusion matrix
plot_confusion_matrix(y_test, y_pred_ar, classes=class_names,
                      title='Confusion matrix, without normalization')
# Plot normalized confusion matrix
plot_confusion_matrix(y_test, y_pred_ar, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')
plt.show()